@Narsil Narsil commented Aug 14, 2024

For CUDA graphs to be capturable, kernels need to be launched on a specific stream.
Importantly for PyTorch, it must be the same stream that the other regular kernels are launched on (the current stream).

This PR is just a demonstration of how to do it.

I do not have a short reproducible script, but something as simple as this:

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
     .....
     
g.replay()

Should be enough.
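For readers who want something runnable, here is a minimal sketch of the same capture-and-replay pattern. It uses a built-in elementwise op as a stand-in for the extension kernel, so it only illustrates the mechanism and is not the reproduction for this PR. The function name `capture_and_replay` and the tensor names are hypothetical, and the warm-up on a side stream follows the usual PyTorch recommendation rather than anything stated above.

```python
import torch


def capture_and_replay(n: int = 1024):
    """Capture `x * 2 + 1` into a CUDA graph, then replay it on new input.

    Returns None when no GPU is available, otherwise a tuple of
    (capture ran on a non-default stream, first output element after replay).
    """
    if not torch.cuda.is_available():
        return None  # CUDA graphs require a GPU

    # Graph inputs must live at fixed addresses, so allocate them up front.
    static_in = torch.zeros(n, device="cuda")

    # Warm up on a side stream so lazy initialization happens before capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(3):
            static_out = static_in * 2 + 1
    torch.cuda.current_stream().wait_stream(side)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        # Inside the context the current stream is the capture stream.
        # A custom kernel launched on any other stream would not be captured.
        on_side_stream = bool(
            torch.cuda.current_stream() != torch.cuda.default_stream()
        )
        static_out = static_in * 2 + 1  # stand-in for the extension kernel

    # Update the input in place, then rerun the recorded kernels.
    static_in.fill_(3.0)
    g.replay()
    torch.cuda.synchronize()
    return on_side_stream, static_out[0].item()


if __name__ == "__main__":
    print(capture_and_replay())
```

Swapping the stand-in line for a call into the extension should show whether its kernels are launched on PyTorch's current stream: a kernel launched on a different stream is either rejected during capture or silently left out of the graph.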

If this PR is accepted in some form, I can work on updating the other kernels.

@turboderp
Member

I actually already have this. Sorry for being a little absent on the main branch, but I've been working overtime on the dev_tp branch where I now have a fairly crude but functional tensor-parallel implementation up and running. There's a really big update in the pipeline for 0.1.9 I guess :)

Anyway, before the TP stuff I also implemented graphs per-module using stream capture. Figuring out a way to capture graphs while also recording indices for parameter updates on subsequent launches was rough, but I think I came up with a decent solution in the end. Full log for all that is here. (After that is mostly TP code.)

I already had most operations running in C++, so there isn't the same benefit to graphs as you'd normally see in Torch where graphs eliminate some of the ridiculous CPU overhead caused by Python, but it still gives a 2-4% speedup on the device side (i.e. when GPU-bound) and it does still reduce CPU overhead somewhat.

@Narsil
Author

Narsil commented Aug 14, 2024

Nice!

In the meantime I'll do some simple patches for the few kernels we use in text-generation-inference, and wait eagerly for 0.1.9! :)

Cheers.

Shall I close the branch?

@turboderp
Member

Would this actually compile? at::cuda::getCurrentCUDAStream() returns a CUDAStream reference, not a cudaStream_t. I think you want the stream() function which gets you the corresponding CUDA handle.

@Narsil
Author

Narsil commented Aug 14, 2024

This compiles.

This is a hotfix we've had for a long time in many kernels to make CUDA graphs work.

Somehow it doesn't seem to work in this case, but that could also be linked to max_dq_rows, which I'm not sure I'm setting to a correct value (or maybe something changed in the scratch buffer?).
